{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CL089-JvFAXe"
   },
   "source": [
    "# Fine-tuning for Semantic Segmentation with 🤗 Transformers\n",
    "\n",
    "In this notebook, you'll learn how to fine-tune a pretrained vision model for semantic segmentation on a custom dataset in PyTorch. The idea is to add a randomly initialized segmentation head on top of a pretrained encoder and fine-tune the whole model on a labeled dataset. You can find an accompanying blog post [here](https://huggingface.co/blog/fine-tune-segformer)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2UEtIWt0sGkR"
   },
   "source": [
    "## Model\n",
    "\n",
    "This notebook is built for the [SegFormer model](https://huggingface.co/docs/transformers/model_doc/segformer#transformers.SegformerForSemanticSegmentation) and should run on any semantic segmentation dataset. You can adapt it to other supported semantic segmentation models such as [MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j9Fe4GuFtRMs"
   },
   "source": [
    "## Data augmentation\n",
    "\n",
    "This notebook leverages `torchvision`'s [`transforms` module](https://pytorch.org/vision/stable/transforms.html) for applying data augmentation. Using other augmentation libraries like `albumentations` is also [supported](https://github.com/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb).\n",
    "\n",
    "---\n",
    "\n",
    "Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the two parameters below (the model checkpoint and the batch size), and the rest of the notebook should run smoothly.\n",
    "\n",
    "In this notebook, we'll fine-tune from the [`nvidia/mit-b0`](https://huggingface.co/nvidia/mit-b0) checkpoint, but note that there are others [available on the Hub](https://huggingface.co/models?pipeline_tag=image-segmentation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6EGlFGE8uyBS"
   },
   "outputs": [],
   "source": [
    "model_checkpoint = \"nvidia/mit-b0\"  # pre-trained model from which to fine-tune\n",
    "batch_size = 4  # batch size for training and evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d32sBZeq_HQ2"
   },
   "source": [
    "Before we start, let's install the `datasets`, `transformers`, and `evaluate` libraries. We also install Git-LFS to upload the model checkpoints to the Hub.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vAvgC48er3Ih"
   },
   "outputs": [],
   "source": [
    "!pip -q install datasets transformers evaluate\n",
    "\n",
    "!git lfs install\n",
    "!git config --global credential.helper store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSq03UMRvjRS"
   },
   "source": [
    "If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed, or run the `pip install` command above with the `--upgrade` flag.\n",
    "\n",
    "You can share the resulting model with the community. By pushing the model to the Hub, others can discover your model and build on top of it. You also get an automatically generated model card that documents how the model works and a widget that will allow anyone to try out the model directly in the browser. To enable this, you'll need to log in to your account."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PlSGaFpFuQSW"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XalxdrirGkLl"
   },
   "source": [
    "## Fine-tuning a model on a semantic segmentation task\n",
    "\n",
    "Given an image, the goal is to assign every pixel to a particular category (such as table). The screenshot below is taken from a [SegFormer fine-tuned on ADE20k](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) - try out the inference widget!\n",
    "\n",
    "<img src=\"https://i.ibb.co/8B8xy1R/image.png\" alt=\"drawing\" width=\"600\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mcE455KaG687"
   },
   "source": [
    "### Loading the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RD_G2KJgG_bU"
   },
   "source": [
    "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download our custom dataset into a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict).\n",
    "\n",
    "We're using the Sidewalk dataset, which is a dataset of sidewalk images gathered in Belgium in the summer of 2021. You can learn more about the dataset [here](https://huggingface.co/datasets/segments/sidewalk-semantic). Note that you will be required to share your contact information (email and username) with the repository authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 281,
     "referenced_widgets": [
      "6de4fc2807bc4bb0825befbbaca366c9",
      "77c817c830314013805fbd0d203ea8b8",
      "f1f9875ed07b480591c975222fc95950",
      "3a36ff18a13e46768228e7128e02002b",
      "c7d020044efb4725a24a0a84ae7409fa",
      "ad66d4abe2e2450cb7c8a573dc50d9ff",
      "3e3d572c197e464da6a3ce0cfc4cc89e",
      "59c9c5f2ef6d4625a97a4bc9bdc331b2",
      "9725bcfbf9d54824974e8e69c64ca6fd",
      "83896f5e610144adb992c63e9cbd56ff",
      "ef4a5405c1d34d5badf0a813a8ea1fa5",
      "289120fb74224df4a9d2aa2821432ba5",
      "f7e2d0f3ee6e4231b22616fbf933bf02",
      "78a551a0db66493f876fbe18030c3174",
      "f3174b816bec4c918cd3ab4b10ec0efe",
      "4fa96e02fc4b44cb9996bbef089fd74c",
      "0031937aefa0465683e219549c935dd7",
      "d1b7953db7a6462f831fb075199542ef",
      "1c4348e789704fe2b25cb8698e88ca8e",
      "bff07d8078824308b8aabf71faf16e88",
      "12a94046754c49868ad14d688c48d88f",
      "5ad48a5ce1dc426faa6e790064418c9b",
      "3de3cc234bda43c1ad0b354dbc474ef6",
      "d9f56ee4b21f4e0e8693232cbb490cdd",
      "05621203ac434c5f93d355bb6ce6db3d",
      "d134802237f84c73a2717eed6cc1ecb4",
      "afce9206edbb444a944c291d5282c217",
      "a60fbc5906394df9858d898670a7f819",
      "1ee064e37e61448cb9a34fa7e66a3e34",
      "01ecfb419b3f460b8377c79793740982",
      "43f8dc298b7941bf8ca651015dec533c",
      "32988ad67af149489bfd676c0942dbe4",
      "a687b33460fa461eafec017738c863de",
      "c5030cc8ebff4e628da62e8835965f33",
      "33ad7766b2c146afb54a3470f352b8ee",
      "915fad5434a545b086baf639459d4558",
      "79894b070903437584f7aca3052e50de",
      "ad87c3f60ac7492ab8aaa471630c41c7",
      "e74bfbec17f94789bfb8e367983aa193",
      "dccebc6c988d49c887484b78f4ee95e9",
      "d04bc0eadaf242b6a81a9ade962c4de2",
      "4ab875034a084c338994bc7909053241",
      "3e2ff151429e49f5aaffd0a940d36805",
      "0459a4d9fcbd4478a52481ac016ddd78",
      "23735d2bbcf24587859d5d36bf901fd0",
      "ca4341aa40384d228a99074543527b31",
      "974c53ab4fb94ee3a86f650bba9562c1",
      "3d40886e38454a618a8a9a756f5c458d",
      "8db4e36ccc6a4d4492b78a50c7ec2409",
      "ab102484e15d44a4b1182029637dfc30",
      "389469716c4e42c884ca47c6659d5612",
      "4d6b9ad04eab4a70bf1df3f43a7f5f94",
      "b969b475e1b149dda5b119e37d7761c7",
      "c298ba164a35417b9b1fff7a97bd7f7f",
      "c8cd775a94a944d9a0e0ebeb2c3645f8",
      "2cbb298d54ac4771be87a85cbcb5dff4",
      "1a5bc3a601a949b5ab59b4a4b776f2f0",
      "8c9c506bf6dd47fd882e8d70581e54c2",
      "8652d0f8599843e98edfe6c21de5be9c",
      "3b280d74e111402d94a18f61fe8c4478",
      "998feff5e78b491dacc43d045c73cb8c",
      "9747349d77e14ab3a47fd4c479970a0a",
      "47cf642995f44f04a56be02be4ab0a87",
      "dadc0338ea954e478246a588b2aa1dc7",
      "219ab526698c43f895ca102befa0135c",
      "ff2adb06e63e47798f7130e9630c1255",
      "f596082d8bda42458543179605fb2dc2",
      "af8c7c09c55341a6a14b4799ae2dff44",
      "396a68b4f1b947b9876a9587b6958fce",
      "a3a68016faa148158bd9f4d805863295",
      "3bea3ea92ee8481fb036fa4e07c08009",
      "0d05591c5dd74c26b9399ece8cea12b0",
      "ea01151241174b9da977485761a27463",
      "b2985292e9d24f8ab6c6523056d8371a",
      "4af03fff36ed4c569c86908f228f7954",
      "ae1bd290c1474450a169593db137974f",
      "097c5b79a85642efa389e68aa5df428c"
     ]
    },
    "id": "U10po8Q3w9ZJ",
    "outputId": "292c707c-2cae-4818-fe23-f9c71efe0f37"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "hf_dataset_identifier = \"segments/sidewalk-semantic\"\n",
    "ds = load_dataset(hf_dataset_identifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eq8mwsZU2j6t"
   },
   "source": [
    "Let us also load the Mean IoU metric, which we'll use to evaluate our model both during and after training. \n",
    "\n",
    "IoU (short for Intersection over Union) tells us the amount of overlap between two sets. In our case, these sets will be the ground-truth segmentation map and the predicted segmentation map. To learn more, you can check out [this article](https://learnopencv.com/intersection-over-union-iou-in-object-detection-and-segmentation/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "4b1915eae45c4710b6a3f8495093956d",
      "b42d735eeb0043e7b318004de525b55b",
      "d1148b579b6e41af99f2463a799d6c90",
      "501ff9d5db294d928f6367741bdd8f37",
      "a494ba6485a14473abcff9a343e1038e",
      "6fda21b649114145af09deb02b2d1e95",
      "dca0b5e19b504879b6e821f43ccbed00",
      "e9bf0dfb0a574cceb00953779da8d5f3",
      "8de21d3ac33f47a39fbefa6395d703df",
      "2da9282150f94994ad230dd70d9ddbbb",
      "faa9ead661124af59981d5b0cb5d38b9"
     ]
    },
    "id": "8UGse36eLeeb",
    "outputId": "682ed0ad-b48b-4988-b8e2-aa25d216ae76"
   },
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"mean_iou\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r8mTmFdlHOmN"
   },
   "source": [
    "The `ds` object itself is a `DatasetDict`, which contains one key per split (in this case, only \"train\" for a training split)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7tjOWPQYLq4u",
    "outputId": "cf7abc33-2ce4-40c6-ee66-e8ce47a1c208"
   },
   "outputs": [],
   "source": [
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Yt_phZQsxz5M"
   },
   "source": [
    "Here, the `features` tell us what each example consists of:\n",
    "\n",
    "* `pixel_values`: the actual image\n",
    "* `label`: segmentation mask\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nfPPNjthI3u2"
   },
   "source": [
    "To access an actual element, you need to select a split first, then give an index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 217
    },
    "id": "BujWoSgyMQlw",
    "outputId": "59c0dc81-3be5-4b38-f3bf-037133024e57"
   },
   "outputs": [],
   "source": [
    "example = ds[\"train\"][10]\n",
    "example[\"pixel_values\"].resize((200, 200))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 217
    },
    "id": "-p7R9KP0yUOa",
    "outputId": "c0487523-1ffe-4cb6-bfe7-dec0033a5d67"
   },
   "outputs": [],
   "source": [
    "example[\"label\"].resize((200, 200))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vp-fIQzvyTHL"
   },
   "source": [
    "Each of the pixels above is associated with a particular category. Let's load all the categories that are associated with the dataset. Let's also create an `id2label` dictionary to decode them back to strings and see what they are. The inverse `label2id` will be useful too, when we load the model later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "c71c1a959c43465e9fa63a94e72e0681",
      "22d86a1d282448a190ff5f81d5086f16",
      "108d9a14e9bd46088a1b3526e9607c48",
      "f53b548093704a4483f8825f2c7eb3a6",
      "2be22416cdaf4ab1a48145dcdbafd57f",
      "31288c6294644f4c890374a08c7ec1ec",
      "abd78c14e61843ef89699b2ad68d970c",
      "3dbd784ed0724b7b8f3bf3b5d3d87919",
      "3f71c41ea8524623b129bfafddf293db",
      "90d261d919a54a91bacba98abaf6b35c",
      "d82c57f3e3484f149b5a8d63c97ac582"
     ]
    },
    "id": "Op_AGsW0y5C3",
    "outputId": "59041978-a8fa-4aad-ef94-80013c58c288"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "import json\n",
    "\n",
    "filename = \"id2label.json\"\n",
    "id2label = json.load(\n",
    "    open(hf_hub_download(hf_dataset_identifier, filename, repo_type=\"dataset\"), \"r\")\n",
    ")\n",
    "id2label = {int(k): v for k, v in id2label.items()}\n",
    "label2id = {v: k for k, v in id2label.items()}\n",
    "\n",
    "num_labels = len(id2label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2Rg_WraEzMb7",
    "outputId": "9cf28a2d-b433-4df4-c961-c987ba242ef6"
   },
   "outputs": [],
   "source": [
    "num_labels, list(label2id.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FmMUQVLXLj_I"
   },
   "source": [
    "**Note**: This dataset specifically sets the 0th index as being `unlabeled`. We want to take this information into consideration while computing the loss. Specifically, we'll want to mask the pixels labeled `unlabeled` and avoid computing the loss for them, since they contribute little to training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7iCzXnImeWle"
   },
   "source": [
    "Let's shuffle the dataset and split it into a train and a test set. We'll explicitly define a random seed to use when calling `ds.shuffle()` to ensure our results are the same each time we run this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MC4OBSUret9N"
   },
   "outputs": [],
   "source": [
    "ds = ds.shuffle(seed=1)\n",
    "ds = ds[\"train\"].train_test_split(test_size=0.2, seed=1)  # seed the split too, so it is reproducible\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4zxoikSOjs0K"
   },
   "source": [
    "### Preprocessing the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WTupOU88p1lK"
   },
   "source": [
    "Before we can feed these images to our model, we need to preprocess them. \n",
    "\n",
    "Preprocessing images typically comes down to (1) resizing them to a particular size and (2) normalizing the color channels (R, G, B) using a mean and standard deviation. These are referred to as **image transformations**.\n",
    "\n",
    "To make sure we (1) resize to the appropriate size and (2) use the appropriate image mean and standard deviation for the model architecture we are going to use, we instantiate what is called a feature extractor with the `AutoFeatureExtractor.from_pretrained` method.\n",
    "\n",
    "This feature extractor is a minimal preprocessor that can be used to prepare images for model training and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PJJQ4s9S0Iag"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoFeatureExtractor\n",
    "\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)\n",
    "feature_extractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F7kWNDgSzzBo"
   },
   "outputs": [],
   "source": [
    "from torchvision.transforms import ColorJitter\n",
    "from transformers import SegformerFeatureExtractor\n",
    "\n",
    "feature_extractor = SegformerFeatureExtractor()\n",
    "jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1) \n",
    "\n",
    "def train_transforms(example_batch):\n",
    "    images = [jitter(x) for x in example_batch['pixel_values']]\n",
    "    labels = [x for x in example_batch['label']]\n",
    "    inputs = feature_extractor(images, labels)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "def val_transforms(example_batch):\n",
    "    images = [x for x in example_batch['pixel_values']]\n",
    "    labels = [x for x in example_batch['label']]\n",
    "    inputs = feature_extractor(images, labels)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "# Set transforms\n",
    "train_ds.set_transform(train_transforms)\n",
    "test_ds.set_transform(val_transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also defined some data augmentations to make our model more resilient to different lighting conditions. We used the [`ColorJitter`](https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html) transform from `torchvision` to randomly change the brightness, contrast, saturation, and hue of the images in the batch.\n",
    "\n",
    "Also, notice the difference between the transformations applied to the train and test splits. We're only applying jittering to the training split and not to the test split. Data augmentation is usually a training-only step and isn't applied during evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HOXmyPQ76Qv9"
   },
   "source": [
    "### Training the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0a-2YT7O6ayC"
   },
   "source": [
    "Now that our data is ready, we can download the pretrained model and fine-tune it. We will use the `SegformerForSemanticSegmentation` class. Calling the `from_pretrained` method on it will download and cache the weights for us. As the label ids and the number of labels are dataset dependent, we pass `num_labels`, `id2label`, and `label2id` alongside the `model_checkpoint` here. This will make sure a custom segmentation head is created (with a custom number of output neurons)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qMdkERPg3ab3"
   },
   "outputs": [],
   "source": [
    "from transformers import SegformerForSemanticSegmentation\n",
    "\n",
    "\n",
    "model = SegformerForSemanticSegmentation.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    num_labels=num_labels,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    "    ignore_mismatched_sizes=True,  # Will ensure the segmentation specific components are reinitialized.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5umptad3k3g"
   },
   "source": [
    "The warning is telling us we are throwing away some weights (the weights and bias of the `decode_head` layer) and randomly initializing some others (the weights and bias of a new `decode_head` layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.\n",
    "\n",
    "To fine-tune the model, we'll use Hugging Face's [Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer). To use the `Trainer`, we'll need to define the training configuration and any evaluation metrics we might want to use.\n",
    "\n",
    "First, we'll set up the [`TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). This defines all training hyperparameters, such as the learning rate, the number of epochs, how often to save the model, and so on. We also specify that the model should be pushed to the Hub after training (`push_to_hub=True`) and give it a model name (`hub_model_id`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6K0B3oqs3PR-"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "epochs = 50\n",
    "lr = 0.00006\n",
    "\n",
    "hub_model_id = \"segformer-b0-finetuned-segments-sidewalk-2\"\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    \"segformer-b0-finetuned-segments-sidewalk-outputs\",\n",
    "    learning_rate=lr,\n",
    "    num_train_epochs=epochs,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=20,\n",
    "    eval_steps=20,\n",
    "    logging_steps=1,\n",
    "    eval_accumulation_steps=5,\n",
    "    load_best_model_at_end=True,\n",
    "    push_to_hub=True,\n",
    "    hub_model_id=hub_model_id,\n",
    "    hub_strategy=\"end\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T3gr2e6ggAn0"
   },
   "source": [
    "Next, we'll define a function that computes the evaluation metric we want to work with. Because we're doing semantic segmentation, we'll use the [mean Intersection over Union (mIoU)](https://huggingface.co/spaces/evaluate-metric/mean_iou), which is directly accessible in the [`evaluate` library](https://huggingface.co/docs/evaluate/index). IoU represents the overlap of segmentation masks. Mean IoU is the average of the IoU of all semantic classes. Take a look at [this blog post](https://www.jeremyjordan.me/evaluating-image-segmentation-models/) for an overview of evaluation metrics for image segmentation.\n",
    "\n",
    "Because our model outputs logits with dimensions height/4 and width/4, we have to upscale them before we can compute the mIoU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eZeP3LdtgBj-"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"mean_iou\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    with torch.no_grad():\n",
    "        logits, labels = eval_pred\n",
    "        logits_tensor = torch.from_numpy(logits)\n",
    "        # upsample the logits to the size of the labels, then take the most likely class per pixel\n",
    "        logits_tensor = nn.functional.interpolate(\n",
    "            logits_tensor,\n",
    "            size=labels.shape[-2:],\n",
    "            mode=\"bilinear\",\n",
    "            align_corners=False,\n",
    "        ).argmax(dim=1)\n",
    "\n",
    "        pred_labels = logits_tensor.detach().cpu().numpy()\n",
    "        # currently using _compute instead of compute\n",
    "        # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576\n",
    "        metrics = metric._compute(\n",
    "            predictions=pred_labels,\n",
    "            references=labels,\n",
    "            num_labels=len(id2label),\n",
    "            ignore_index=0,\n",
    "            reduce_labels=feature_extractor.reduce_labels,\n",
    "        )\n",
    "\n",
    "        # add per category metrics as individual key-value pairs\n",
    "        per_category_accuracy = metrics.pop(\"per_category_accuracy\").tolist()\n",
    "        per_category_iou = metrics.pop(\"per_category_iou\").tolist()\n",
    "\n",
    "        metrics.update({f\"accuracy_{id2label[i]}\": v for i, v in enumerate(per_category_accuracy)})\n",
    "        metrics.update({f\"iou_{id2label[i]}\": v for i, v in enumerate(per_category_iou)})\n",
    "\n",
    "        return metrics\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aki6vBMbgHJ8"
   },
   "source": [
    "Finally, we can instantiate a `Trainer` object.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DPvhqbUVgIhl"
   },
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    tokenizer=feature_extractor,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y7ebFys1gO25"
   },
   "source": [
    "Notice that we're passing `feature_extractor` to the `Trainer`. This will ensure the feature extractor is also uploaded to the Hub along with the model checkpoints.\n",
    "\n",
    "Now that our trainer is set up, training is as simple as calling the `train` method. We don't need to worry about managing our GPU(s); the trainer will take care of that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MVTWgGeQgRpB"
   },
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I96TvX-KgT56"
   },
   "source": [
    "When we're done with training, we can push our fine-tuned model to the Hub.\n",
    "\n",
    "This will also automatically create a model card with our results. We'll supply some extra information in kwargs to make the model card more complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m8YuKExogV4P"
   },
   "outputs": [],
   "source": [
    "kwargs = {\n",
    "    \"tags\": [\"vision\", \"image-segmentation\"],\n",
    "    \"finetuned_from\": model_checkpoint,  # the checkpoint we fine-tuned from (defined at the top of the notebook)\n",
    "    \"dataset\": hf_dataset_identifier,\n",
    "}\n",
    "\n",
    "trainer.push_to_hub(**kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dFzJ8Xb5g1Vi"
   },
   "source": [
    "## Inference\n",
    "\n",
    "Now comes the exciting part -- using our fine-tuned model! In this section, we'll show how you can load your model from the hub and use it for inference. \n",
    "\n",
    "However, you can also try out your model directly on the Hugging Face Hub, thanks to the cool widgets powered by the [hosted inference API](https://api-inference.huggingface.co/docs/python/html/index.html). If you pushed your model to the Hub in the previous step, you should see an inference widget on your model page. You can add default examples to the widget by defining example image URLs in your model card. See [this model card](https://huggingface.co/segments-tobias/segformer-b0-finetuned-segments-sidewalk/blob/main/README.md) as an example.\n",
    "\n",
    "<figure class=\"image table text-center m-0 w-full\">\n",
    "    <video \n",
    "        alt=\"The interactive widget of the model\"\n",
    "        style=\"max-width: 70%; margin: auto;\"\n",
    "        autoplay loop autobuffer muted playsinline\n",
    "    >\n",
    "      <source src=\"assets/56_fine_tune_segformer/widget.mp4\" poster=\"assets/56_fine_tune_segformer/widget-poster.png\" type=\"video/mp4\">\n",
    "  </video>\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QXY3erc6g8LP"
   },
   "source": [
    "### Use the model from the Hub\n",
    "\n",
    "We'll first load the model from the Hub using `SegformerForSemanticSegmentation.from_pretrained()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gfbdz9Hpg9mI"
   },
   "outputs": [],
   "source": [
    "from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n",
    "\n",
    "feature_extractor = SegformerFeatureExtractor.from_pretrained(model_checkpoint)\n",
    "hf_username = \"segments-tobias\"  # replace with your own Hub username to load the model you just pushed\n",
    "model = SegformerForSemanticSegmentation.from_pretrained(f\"{hf_username}/{hub_model_id}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SQJkEqGxQwz6"
   },
   "source": [
    "Next, we'll load an image from our test dataset and its associated ground truth segmentation label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R57X_iNkqv6H"
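   },
   "outputs": [],
   "source": [
    "# `test_ds` has an on-the-fly transform attached, so reset the format to get the raw PIL images back\n",
    "raw_test_ds = test_ds.with_format(None)\n",
    "image = raw_test_ds[0]['pixel_values']\n",
    "gt_seg = raw_test_ds[0]['label']\n",
    "image"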
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7m7IfMv6R3_5"
   },
   "source": [
    "To segment this test image, we first need to prepare the image using the feature extractor. Then we'll forward it through the model.\n",
    "\n",
    "We also need to remember to upscale the output logits to the original image size. In order to get the actual category predictions, we just have to apply an `argmax` on the logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8nNSSqEUBS2v"
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "inputs = feature_extractor(images=image, return_tensors=\"pt\")\n",
    "outputs = model(**inputs)\n",
    "logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)\n",
    "\n",
    "# First, rescale logits to original image size\n",
    "upsampled_logits = nn.functional.interpolate(\n",
    "    logits,\n",
    "    size=image.size[::-1], # (height, width)\n",
    "    mode='bilinear',\n",
    "    align_corners=False\n",
    ")\n",
    "\n",
    "# Second, apply argmax on the class dimension\n",
    "pred_seg = upsampled_logits.argmax(dim=1)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oyHddde_SOgv"
   },
   "source": [
    "Now it's time to display the result. The next cell defines the colors for each category, so that they match the \"category coloring\" on Segments.ai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "Ky_8gHCRCJHj"
   },
   "outputs": [],
   "source": [
    "#@title `def sidewalk_palette()`\n",
    "\n",
    "def sidewalk_palette():\n",
    "    \"\"\"Sidewalk palette that maps each class to RGB values.\"\"\"\n",
    "    return [\n",
    "        [0, 0, 0],\n",
    "        [216, 82, 24],\n",
    "        [255, 255, 0],\n",
    "        [125, 46, 141],\n",
    "        [118, 171, 47],\n",
    "        [161, 19, 46],\n",
    "        [255, 0, 0],\n",
    "        [0, 128, 128],\n",
    "        [190, 190, 0],\n",
    "        [0, 255, 0],\n",
    "        [0, 0, 255],\n",
    "        [170, 0, 255],\n",
    "        [84, 84, 0],\n",
    "        [84, 170, 0],\n",
    "        [84, 255, 0],\n",
    "        [170, 84, 0],\n",
    "        [170, 170, 0],\n",
    "        [170, 255, 0],\n",
    "        [255, 84, 0],\n",
    "        [255, 170, 0],\n",
    "        [255, 255, 0],\n",
    "        [33, 138, 200],\n",
    "        [0, 170, 127],\n",
    "        [0, 255, 127],\n",
    "        [84, 0, 127],\n",
    "        [84, 84, 127],\n",
    "        [84, 170, 127],\n",
    "        [84, 255, 127],\n",
    "        [170, 0, 127],\n",
    "        [170, 84, 127],\n",
    "        [170, 170, 127],\n",
    "        [170, 255, 127],\n",
    "        [255, 0, 127],\n",
    "        [255, 84, 127],\n",
    "        [255, 170, 127],\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f4BzL0ISSePY"
   },
   "source": [
    "The next function overlays the output segmentation map on the original image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G3HqZXyQB7gJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_seg_overlay(image, seg):\n",
    "  color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3\n",
    "  palette = np.array(sidewalk_palette())\n",
    "  for label, color in enumerate(palette):\n",
    "      color_seg[seg == label, :] = color\n",
    "\n",
    "  # Show image + mask\n",
    "  img = np.array(image) * 0.5 + color_seg * 0.5\n",
    "  img = img.astype(np.uint8)\n",
    "\n",
    "  return img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-yEXFytLSkht"
   },
   "source": [
    "We'll display the result next to the ground-truth mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vnSn2A2U0RMw"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "pred_img = get_seg_overlay(image, pred_seg)\n",
    "gt_img = get_seg_overlay(image, np.array(gt_seg))\n",
    "\n",
    "f, axs = plt.subplots(1, 2)\n",
    "f.set_figheight(30)\n",
    "f.set_figwidth(50)\n",
    "\n",
    "axs[0].set_title(\"Prediction\", {'fontsize': 40})\n",
    "axs[0].imshow(pred_img)\n",
    "axs[1].set_title(\"Ground truth\", {'fontsize': 40})\n",
    "axs[1].imshow(gt_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r3Chx4bXaCYa"
   },
   "source": [
    "What do you think? Would you send our pizza delivery robot on the road with this segmentation information?\n",
    "\n",
    "The result might not be perfect yet, but we can always expand our dataset to make the model more robust. We can now also go train a larger SegFormer model, and see how it stacks up. If you want to explore further beyond this notebook, here are some things you can try next:\n",
    "\n",
    "* Train the model for longer. \n",
    "* Try out the different segmentation-specific training augmentations from libraries like [`albumentations`](https://albumentations.ai/docs/getting_started/mask_augmentation/). \n",
    "* Try out a larger variant of the SegFormer model family or try an entirely new model family like MobileViT. "
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "machine_shape": "hm",
   "provenance": []
  },
  "gpuClass": "premium",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
